Most properties of physical systems are expressed in terms of geometric tensors. Scalars (mass), vectors (velocities, forces, polarizations), matrices (polarizability, moment of inertia), and higher-rank tensors are all geometric tensors.
Geometric tensors are commonly expressed with Cartesian indices $(x, y, z)$; we will call these Cartesian tensors. However, there is an equally expressive way of representing geometric tensors: as tensors in the irreducible representation basis (irrep tensors).
Whereas for Cartesian tensors the indices can be interpreted as information along $(x, y, z)$, irrep tensors are indexed by which irreducible representation (irrep) of $O(3)$ they are associated with. You can always convert between Cartesian and irrep bases.
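As a concrete illustration of such a change of basis, here is a minimal NumPy sketch (not part of se3cnn/e3nn; the helper names are hypothetical) that splits a Cartesian rank-2 tensor, a $3 \times 3$ matrix with 9 entries, into pieces that transform under rotations as $L = 0$ (the trace, 1 number), $L = 1$ (the antisymmetric part, 3 numbers), and $L = 2$ (the symmetric traceless part, 5 numbers). It keeps the $L = 2$ piece as a matrix rather than expanding it in a spherical-harmonic basis, so it shows the decomposition but not the final choice of coordinates.

```python
import numpy as np

def decompose_rank2(T):
    """Split a 3x3 Cartesian tensor into parts that transform as L = 0, 1 and 2.

    Returns (scalar, vector, sym_traceless):
      scalar        : trace / 3                                  -> 1 number  (L = 0)
      vector        : independent entries of the antisymmetric part -> 3 numbers (L = 1)
      sym_traceless : symmetric traceless 3x3 matrix             -> 5 independent numbers (L = 2)
    """
    T = np.asarray(T, dtype=float)
    scalar = np.trace(T) / 3.0
    antisym = 0.5 * (T - T.T)
    # An antisymmetric 3x3 matrix has only 3 independent entries: store them as a vector.
    vector = np.array([antisym[1, 2], antisym[2, 0], antisym[0, 1]])
    sym_traceless = 0.5 * (T + T.T) - scalar * np.eye(3)
    return scalar, vector, sym_traceless

def recompose_rank2(scalar, vector, sym_traceless):
    """Inverse of decompose_rank2: rebuild the Cartesian 3x3 matrix."""
    a, b, c = vector
    antisym = np.array([[0.0, c, -b],
                        [-c, 0.0, a],
                        [b, -a, 0.0]])
    return scalar * np.eye(3) + antisym + sym_traceless

rng = np.random.default_rng(0)
T = rng.normal(size=(3, 3))
s, v, S = decompose_rank2(T)
print(1 + v.size + 5)  # 9: the same number of entries as the Cartesian matrix
print(np.allclose(recompose_rank2(s, v, S), T))  # the conversion loses no information
```

No information is lost in either direction, which is what makes the two bases equally expressive.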
The irreps of 3D rotations, the group $SO(3)$, are indexed by their angular frequency $L$. The spherical harmonics of degree $L$ form a basis for the degree-$L$ irrep of $SO(3)$: under a rotation, they are mapped to linear combinations of one another in exactly the way that irrep prescribes.
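To make "transform in the same way as the irreps" concrete, the sketch below evaluates the five real spherical harmonics of degree $L = 2$ at random unit vectors, rotates the vectors, and fits a $5 \times 5$ matrix $D$ with $Y_2(R\hat{r}) = D\,Y_2(\hat{r})$. The harmonics are written out as polynomials in $(x, y, z)$ using one common real, orthonormal convention, which may differ in ordering and sign from the one used in se3cnn/e3nn; the helper names are hypothetical and only NumPy is assumed. If the degree-2 harmonics span an irrep, such a $D$ (the real-basis Wigner D-matrix of degree 2) exists exactly and is orthogonal.

```python
import numpy as np

def real_sh_l2(r):
    """Real orthonormal spherical harmonics of degree 2 at unit vectors r of shape (N, 3).

    Columns ordered m = -2, -1, 0, 1, 2 (a common convention; libraries may differ).
    """
    x, y, z = r[:, 0], r[:, 1], r[:, 2]
    c = 0.5 * np.sqrt(15.0 / np.pi)
    return np.stack([
        c * x * y,
        c * y * z,
        0.25 * np.sqrt(5.0 / np.pi) * (3.0 * z**2 - 1.0),
        c * x * z,
        0.5 * c * (x**2 - y**2),
    ], axis=1)

def random_rotation(rng):
    """Random proper rotation (det = +1) from the QR decomposition of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))  # fix the sign ambiguity of QR
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]  # flip one axis to turn a reflection into a rotation
    return q

def fit_wigner_d_l2(R, n_points=200, seed=0):
    """Least-squares fit of D such that Y2(R r) = D Y2(r) for random unit vectors r."""
    rng = np.random.default_rng(seed)
    r = rng.normal(size=(n_points, 3))
    r /= np.linalg.norm(r, axis=1, keepdims=True)
    Y = real_sh_l2(r)            # (N, 5)
    Y_rot = real_sh_l2(r @ R.T)  # row i is evaluated at R r_i
    # Y_rot = Y @ D.T, so solve for D.T in the least-squares sense
    D_T, *_ = np.linalg.lstsq(Y, Y_rot, rcond=None)
    return D_T.T, Y, Y_rot

R = random_rotation(np.random.default_rng(42))
D, Y, Y_rot = fit_wigner_d_l2(R)
print(np.allclose(Y @ D.T, Y_rot))      # True: rotated harmonics are exact linear combinations
print(np.allclose(D @ D.T, np.eye(5)))  # True: D is orthogonal
```

The fit is exact (up to floating-point error) because rotating the argument of a degree-2 harmonic polynomial gives another degree-2 harmonic polynomial, so the five functions never mix with other degrees.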
Wikipedia has a great overview of spherical harmonics. As a quick recap, the spherical harmonics are the Fourier basis for functions on the unit sphere. They have two indices, most commonly called the "degree" $L$ and the "order" $m$, and are commonly parameterized by the spherical coordinate angles $\theta$ and $\phi$.
They are written $Y_{L}^{m}(\theta, \phi)$ for complex spherical harmonics and $Y_{Lm}(\theta, \phi)$ for real spherical harmonics.
In se3cnn, we use real spherical harmonics. There are $2L + 1$ functions (indexed by $m$) for each degree $L$, and all functions of degree $L$ have the same angular frequency. Note that these frequencies must be integers (or half-integers for $SU(2)$) because of the periodic boundary conditions of the sphere.
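The sketch below writes out the real spherical harmonics for $L \le 2$ as polynomials in the components of a unit vector and checks numerically that there are $2L + 1$ of them per degree and that they are orthonormal on the sphere. It uses one common real, orthonormal convention; the ordering and signs inside se3cnn/e3nn may differ, and the function names are hypothetical. The integrals use Gauss-Legendre quadrature in $u = \cos\theta$ times a uniform grid in $\phi$, which is exact for the low-degree polynomials involved here.

```python
import numpy as np

def real_sh_up_to_l2(r):
    """Real orthonormal spherical harmonics for L = 0, 1, 2 at unit vectors r of shape (N, 3).

    Returns {L: array of shape (N, 2L + 1)}, columns ordered m = -L, ..., L.
    """
    x, y, z = r[:, 0], r[:, 1], r[:, 2]
    c1 = np.sqrt(3.0 / (4.0 * np.pi))
    c2 = 0.5 * np.sqrt(15.0 / np.pi)
    return {
        0: np.full((len(r), 1), 0.5 / np.sqrt(np.pi)),
        1: c1 * np.stack([y, z, x], axis=1),
        2: np.stack([
            c2 * x * y,
            c2 * y * z,
            0.25 * np.sqrt(5.0 / np.pi) * (3.0 * z**2 - 1.0),
            c2 * x * z,
            0.5 * c2 * (x**2 - y**2),
        ], axis=1),
    }

def sphere_quadrature(n_u=10, n_phi=20):
    """Points on the unit sphere and weights that integrate low-degree polynomials exactly."""
    u, w_u = np.polynomial.legendre.leggauss(n_u)  # nodes/weights on [-1, 1] for u = cos(theta)
    phi = 2.0 * np.pi * np.arange(n_phi) / n_phi
    U, PHI = np.meshgrid(u, phi, indexing="ij")
    s = np.sqrt(1.0 - U**2)  # sin(theta)
    points = np.stack([s * np.cos(PHI), s * np.sin(PHI), U], axis=-1).reshape(-1, 3)
    weights = np.outer(w_u, np.full(n_phi, 2.0 * np.pi / n_phi)).reshape(-1)
    return points, weights

points, weights = sphere_quadrature()
Y = real_sh_up_to_l2(points)
for L, block in Y.items():
    print(L, block.shape[1])  # 2L + 1 functions per degree: 1, 3, 5
Y_all = np.concatenate([Y[L] for L in range(3)], axis=1)  # (N, 9), i.e. (L_max + 1)**2 columns
gram = Y_all.T @ (weights[:, None] * Y_all)  # integrals of Y_i * Y_j over the sphere
print(np.allclose(gram, np.eye(9)))  # True: orthonormal
```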
The irreps of 3D rotations and inversion $(x, y, z) \rightarrow (-x, -y, -z)$, the group $O(3)$, are indexed by their angular frequency $L$ and their parity $p$. Spherical harmonics (which transform as irreps of $SO(3)$) have definite parity, i.e. definite behavior under inversion: odd parity (flips sign under inversion) for odd $L$ and even parity (does not flip sign under inversion) for even $L$. In other words, for a unit vector $\hat{r}$, $Y_{Lm}(-\hat{r}) = (-1)^L Y_{Lm}(\hat{r})$.
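The rule that a degree-$L$ spherical harmonic picks up a factor $(-1)^L$ under inversion, $Y_{Lm}(-\hat{r}) = (-1)^L Y_{Lm}(\hat{r})$, can be checked numerically. The sketch below uses one sample real spherical harmonic per degree, written as a polynomial in $(x, y, z)$; the dictionary and function names are hypothetical and not part of se3cnn.

```python
import numpy as np

# Hypothetical sample: one real spherical harmonic per degree, as a polynomial of a unit vector.
# The normalization constants are included but do not matter for the parity check.
SAMPLE_SH = {
    0: lambda x, y, z: 0.5 / np.sqrt(np.pi) * np.ones_like(x),
    1: lambda x, y, z: np.sqrt(3.0 / (4.0 * np.pi)) * z,
    2: lambda x, y, z: 0.5 * np.sqrt(15.0 / np.pi) * x * y,
    3: lambda x, y, z: 0.25 * np.sqrt(7.0 / np.pi) * (5.0 * z**3 - 3.0 * z),
}

def parity_of(L, n_points=100, seed=0):
    """Return +1 if Y_L(-r) == Y_L(r) and -1 if Y_L(-r) == -Y_L(r) at random unit vectors r."""
    rng = np.random.default_rng(seed)
    r = rng.normal(size=(n_points, 3))
    r /= np.linalg.norm(r, axis=1, keepdims=True)
    f = SAMPLE_SH[L](*r.T)            # values at r
    f_inv = SAMPLE_SH[L](*(-r).T)     # values at the inverted points -r
    if np.allclose(f_inv, f):
        return 1
    if np.allclose(f_inv, -f):
        return -1
    raise ValueError("no definite parity")

print({L: parity_of(L) for L in SAMPLE_SH})  # {0: 1, 1: -1, 2: 1, 3: -1}, i.e. (-1)**L
```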
We use irrep tensors in our network because our convolutional filters are expressed in terms of irreps. Our filters are based on spherical harmonics, and we can additionally specify the parity our filters have. For this reason, it is more precise to say we use irrep tensors than spherical harmonic tensors, since spherical harmonics by themselves have a fixed parity of $(-1)^L$.
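A standard example of why parity has to be specified separately is the cross product of two vectors, such as angular momentum $\vec{r} \times \vec{p}$. It rotates like an $L = 1$ vector but does not flip sign under inversion, so it belongs to the $L = 1$ irrep with even parity, which degree-1 spherical harmonics (odd parity) do not describe. The NumPy check below illustrates both properties for random vectors and a rotation about the $z$ axis; nothing in it is se3cnn-specific.

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = rng.normal(size=3), rng.normal(size=3)
c = np.cross(a, b)

# A proper rotation (det = +1) about the z axis
theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0,            0.0,           1.0]])

# Rotating the inputs rotates the cross product like an ordinary L = 1 vector
print(np.allclose(np.cross(R @ a, R @ b), R @ c))  # True
# Inverting both inputs (a -> -a, b -> -b) leaves the cross product unchanged: even parity
print(np.allclose(np.cross(-a, -b), c))  # True
```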
We use the term spherical tensor for irrep tensors that contain at most one copy (multiplicity one) of each irrep, so that they can be visualized as a function on the sphere.
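Because each degree $L$ contributes $2L + 1$ entries, a spherical tensor that contains every degree from $0$ to $L_{\max}$ exactly once has $\sum_{L=0}^{L_{\max}} (2L + 1) = (L_{\max} + 1)^2$ entries. Assuming the degrees are stored in increasing order, which is the layout the plotting code in this tutorial relies on when it writes to tensor[L**2 + m], degree $L$ occupies indices $L^2$ through $L^2 + 2L$. The small helpers below are hypothetical (not part of se3cnn) and just make that layout explicit.

```python
def spherical_tensor_slices(L_max):
    """Map each degree L to the index range it occupies in a flat spherical tensor.

    Assumes degrees 0..L_max stored in increasing order, one copy each, 2L + 1 entries per degree.
    """
    slices = {}
    start = 0
    for L in range(L_max + 1):
        slices[L] = slice(start, start + 2 * L + 1)
        start += 2 * L + 1
    return slices

def total_size(L_max):
    """Number of entries of a spherical tensor containing every degree up to L_max."""
    return sum(2 * L + 1 for L in range(L_max + 1))

print(spherical_tensor_slices(3))
# {0: slice(0, 1, None), 1: slice(1, 4, None), 2: slice(4, 9, None), 3: slice(9, 16, None)}
print(total_size(3))  # 16 == (3 + 1)**2
```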
se3cnn
To keep track of which spherical tensor entries correspond to which spherical harmonic, we use representation lists, commonly stored in a variable named Rs.
Rs is a list of tuples (mult, L, p), where mult is the multiplicity (the number of copies), L is the degree of the spherical harmonic, and p is the parity. Parity is -1 for odd, 1 for even, and 0 if you only want to use irreps of $SO(3)$ rather than $O(3)$. In most of this tutorial, we set the parity to 0 by default and only deal with irreps of $SO(3)$; tuples written as just (mult, L), as in the examples below, omit the parity.
For example, the Rs of a single vector is
Rs_vec = [(1, 1)]
and two vectors
Rs_2vec = [(2, 1)]
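To make the bookkeeping concrete, here is a minimal pure-Python sketch of what an Rs describes. The helper names normalize_rs, rs_dim and rs_labels are hypothetical (they are not claimed to be the se3cnn API); 2-tuples are treated as having parity 0, and the copies of each irrep are assumed to be stored one after another, each as a contiguous block of $2L + 1$ entries.

```python
def normalize_rs(Rs):
    """Turn (mult, L) entries into (mult, L, p) entries with p = 0 (assumed convention)."""
    out = []
    for entry in Rs:
        if len(entry) == 2:
            mult, L = entry
            p = 0
        else:
            mult, L, p = entry
        out.append((mult, L, p))
    return out

def rs_dim(Rs):
    """Total number of entries described by Rs: the sum of mult * (2L + 1)."""
    return sum(mult * (2 * L + 1) for mult, L, _ in normalize_rs(Rs))

def rs_labels(Rs):
    """Label every index of the flat tensor with (copy, L, p, m), m running from -L to L."""
    labels = []
    for mult, L, p in normalize_rs(Rs):
        for copy in range(mult):
            for m in range(-L, L + 1):
                labels.append((copy, L, p, m))
    return labels

Rs_vec = [(1, 1)]
Rs_2vec = [(2, 1)]
print(rs_dim(Rs_vec), rs_dim(Rs_2vec))  # 3 6
print(rs_labels(Rs_2vec)[:4])  # [(0, 1, 0, -1), (0, 1, 0, 0), (0, 1, 0, 1), (1, 1, 0, -1)]
print(rs_dim([(1, 0), (1, 1, -1), (1, 2)]))  # 9: one copy each of L = 0, 1, 2
```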
First, let's draw the spherical harmonics using the SphericalTensor class defined in e3nn.tensor.spherical_tensor. This is a handy helper class that I wrote for this tutorial so we can quickly manipulate and plot spherical tensors.
import torch
import numpy as np
import e3nn.o3 as o3
import e3nn.rs as rs
from e3nn.tensor.spherical_tensor import SphericalTensor
import plotly
from plotly.subplots import make_subplots
import plotly.graph_objects as go
torch.set_default_dtype(torch.float64)
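L_max = 3
rows = L_max + 1
cols = 2 * L_max + 1
# One 3D subplot per harmonic: row L holds the 2L + 1 harmonics of degree L, centered in the row
specs = [[{'is_3d': True} for i in range(cols)]
         for j in range(rows)]
fig = make_subplots(rows=rows, cols=cols, specs=specs)
for L in range(L_max + 1):
    # Here m runs from 0 to 2L, i.e. it indexes the 2L + 1 functions of degree L
    for m in range(0, 2 * L + 1):
        # Spherical tensor with all degrees up to L; degree L occupies indices L**2 ... L**2 + 2L
        tensor = torch.zeros((L + 1)**2)
        tensor[L**2 + m] = 1.0  # select a single spherical harmonic
        sphten = SphericalTensor(tensor)
        row, col = L + 1, (L_max - L) + m + 1
        # r: coordinates of the plotted surface, f: function value at each point (used for coloring)
        r, f = sphten.plot(relu=False, radius=True)
        r, f = r.numpy(), f.numpy()
        trace = go.Surface(x=r[..., 0], y=r[..., 1], z=r[..., 2], surfacecolor=f)
        # Show a single color bar (on the last subplot) instead of one per trace
        if (L, m) != (L_max, 2 * L_max):
            trace.showscale = False
        fig.add_trace(trace, row=row, col=col)
fig.show()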